{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Survival analysis with Catboost"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "Unlike common supervised regression task where target variable is known and observed during the entire period in training dataset, survival regression has some limitations in terms of partially observed or **censored** target variable.  \n",
    "**Censoring** is a form of missing data problem in which time to event is not observed for reasons such as termination of study before all recruited subjects have shown the event of interest or the subject has left the study prior to experiencing an event.  \n",
    "  \n",
    "Censoring types:\n",
    " * **Right-censored:** When the patient is lost to followup or the event does not occur within the study duration. \n",
    " * **Left-censored:** When the patient had been on risk for disease for a period before entering the study.\n",
    " * **Interval-censored:** When time to event may be known only up to a time interval.\n",
    " * **Uncensored:** When everything is alright:)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Cox proportional model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "Survival models of this type can be viewed as consisting of two parts: the underlying <a href=\"https://en.wikipedia.org/wiki/Hazard_function\"> baseline hazard </a> function, often denoted ${\\displaystyle \\lambda _{0}(t)}$, describing how the risk of event per time unit changes over time at baseline levels of covariates; and the effect parameters, describing how the hazard varies in response to explanatory covariates. \n",
    "\n",
    "Let $X_i=(X_{i1},...,X_{ip})$ be the realized values of the covariates for subject i. The hazard function for the Cox proportional hazards model has the form:\n",
    "\n",
    "$$\\lambda(t|X_i)=\\lambda_0(t)exp(h(X_i)),$$\n",
    "where $h(X_i)$ is the hazard rate for considered sample $X_i$, represented as CatBoost model.This expression gives the hazard function at time $t$ for subject $i$ with covariate vector (explanatory variables) $X_i$.The likelihood of the event to be observed occurring for subject $i$ at time $Y_i$ can be written as:\n",
    "\n",
    "$$L_i(\\beta)=\n",
    "    \\frac{\\lambda(Y_i|X_i)}{\\sum_{j: Y_j \\ge Y_i}\\lambda(Y_i|X_j)} =\n",
    "    \\frac{\\lambda_0(Y_i)exp(h(X_i))}{\\sum_{j: Y_j \\ge Y_i}\\lambda_0(Y_i)exp(h(X_j))} =\n",
    "    \\frac{exp(h(X_i))}{\\sum_{j: Y_j \\ge Y_i}exp(h(X_j))}\n",
    "$$\n",
    "\n",
    "Treating the subjects as if they were statistically independent of each other, the joint probability of all realized events is the following partial likelihood, where the occurrence of the event is indicated by $C_i = 1$:\n",
    "\n",
    "$${\\displaystyle L(\\beta )=\\prod _{i:C_{i}=1}L_{i}(\\beta ).}$$\n",
    "\n",
    "The corresponding log partial likelihood is:\n",
    "\n",
    "$$\\ell(\\beta)=\\sum_{i:C_i=1}(h(X_i) - log \\sum_{j:Y_j \\ge Y_i} exp(h(X_j)))$$\n",
    "\n",
    "This is the function we are trying to optimize via CatBoost."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "### Fit model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sksurv.datasets import load_flchain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "data, target = load_flchain()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>chapter</th>\n",
       "      <th>creatinine</th>\n",
       "      <th>flc.grp</th>\n",
       "      <th>kappa</th>\n",
       "      <th>lambda</th>\n",
       "      <th>mgus</th>\n",
       "      <th>sample.yr</th>\n",
       "      <th>sex</th>\n",
       "      <th>death</th>\n",
       "      <th>futime</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>97.0</td>\n",
       "      <td>Circulatory</td>\n",
       "      <td>1.7</td>\n",
       "      <td>10</td>\n",
       "      <td>5.700</td>\n",
       "      <td>4.860</td>\n",
       "      <td>no</td>\n",
       "      <td>1997</td>\n",
       "      <td>F</td>\n",
       "      <td>True</td>\n",
       "      <td>85.0</td>\n",
       "      <td>85.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>92.0</td>\n",
       "      <td>Neoplasms</td>\n",
       "      <td>0.9</td>\n",
       "      <td>1</td>\n",
       "      <td>0.870</td>\n",
       "      <td>0.683</td>\n",
       "      <td>no</td>\n",
       "      <td>2000</td>\n",
       "      <td>F</td>\n",
       "      <td>True</td>\n",
       "      <td>1281.0</td>\n",
       "      <td>1281.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>94.0</td>\n",
       "      <td>Circulatory</td>\n",
       "      <td>1.4</td>\n",
       "      <td>10</td>\n",
       "      <td>4.360</td>\n",
       "      <td>3.850</td>\n",
       "      <td>no</td>\n",
       "      <td>1997</td>\n",
       "      <td>F</td>\n",
       "      <td>True</td>\n",
       "      <td>69.0</td>\n",
       "      <td>69.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>92.0</td>\n",
       "      <td>Circulatory</td>\n",
       "      <td>1.0</td>\n",
       "      <td>9</td>\n",
       "      <td>2.420</td>\n",
       "      <td>2.220</td>\n",
       "      <td>no</td>\n",
       "      <td>1996</td>\n",
       "      <td>F</td>\n",
       "      <td>True</td>\n",
       "      <td>115.0</td>\n",
       "      <td>115.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>93.0</td>\n",
       "      <td>Circulatory</td>\n",
       "      <td>1.1</td>\n",
       "      <td>6</td>\n",
       "      <td>1.320</td>\n",
       "      <td>1.690</td>\n",
       "      <td>no</td>\n",
       "      <td>1996</td>\n",
       "      <td>F</td>\n",
       "      <td>True</td>\n",
       "      <td>1039.0</td>\n",
       "      <td>1039.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>90.0</td>\n",
       "      <td>Mental</td>\n",
       "      <td>1.0</td>\n",
       "      <td>9</td>\n",
       "      <td>2.010</td>\n",
       "      <td>1.860</td>\n",
       "      <td>no</td>\n",
       "      <td>1997</td>\n",
       "      <td>F</td>\n",
       "      <td>True</td>\n",
       "      <td>1355.0</td>\n",
       "      <td>1355.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>90.0</td>\n",
       "      <td>Mental</td>\n",
       "      <td>0.8</td>\n",
       "      <td>1</td>\n",
       "      <td>0.430</td>\n",
       "      <td>0.880</td>\n",
       "      <td>no</td>\n",
       "      <td>1996</td>\n",
       "      <td>F</td>\n",
       "      <td>True</td>\n",
       "      <td>2851.0</td>\n",
       "      <td>2851.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>90.0</td>\n",
       "      <td>Nervous</td>\n",
       "      <td>1.2</td>\n",
       "      <td>10</td>\n",
       "      <td>2.470</td>\n",
       "      <td>2.700</td>\n",
       "      <td>no</td>\n",
       "      <td>1999</td>\n",
       "      <td>F</td>\n",
       "      <td>True</td>\n",
       "      <td>372.0</td>\n",
       "      <td>372.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>93.0</td>\n",
       "      <td>Respiratory</td>\n",
       "      <td>1.2</td>\n",
       "      <td>9</td>\n",
       "      <td>1.910</td>\n",
       "      <td>2.180</td>\n",
       "      <td>no</td>\n",
       "      <td>1996</td>\n",
       "      <td>F</td>\n",
       "      <td>True</td>\n",
       "      <td>3309.0</td>\n",
       "      <td>3309.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>91.0</td>\n",
       "      <td>Circulatory</td>\n",
       "      <td>0.8</td>\n",
       "      <td>6</td>\n",
       "      <td>0.791</td>\n",
       "      <td>2.220</td>\n",
       "      <td>no</td>\n",
       "      <td>1996</td>\n",
       "      <td>F</td>\n",
       "      <td>True</td>\n",
       "      <td>1326.0</td>\n",
       "      <td>1326.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    age      chapter  creatinine flc.grp  kappa  lambda mgus sample.yr sex  \\\n",
       "0  97.0  Circulatory         1.7      10  5.700   4.860   no      1997   F   \n",
       "1  92.0    Neoplasms         0.9       1  0.870   0.683   no      2000   F   \n",
       "2  94.0  Circulatory         1.4      10  4.360   3.850   no      1997   F   \n",
       "3  92.0  Circulatory         1.0       9  2.420   2.220   no      1996   F   \n",
       "4  93.0  Circulatory         1.1       6  1.320   1.690   no      1996   F   \n",
       "5  90.0       Mental         1.0       9  2.010   1.860   no      1997   F   \n",
       "6  90.0       Mental         0.8       1  0.430   0.880   no      1996   F   \n",
       "7  90.0      Nervous         1.2      10  2.470   2.700   no      1999   F   \n",
       "8  93.0  Respiratory         1.2       9  1.910   2.180   no      1996   F   \n",
       "9  91.0  Circulatory         0.8       6  0.791   2.220   no      1996   F   \n",
       "\n",
       "   death  futime   label  \n",
       "0   True    85.0    85.0  \n",
       "1   True  1281.0  1281.0  \n",
       "2   True    69.0    69.0  \n",
       "3   True   115.0   115.0  \n",
       "4   True  1039.0  1039.0  \n",
       "5   True  1355.0  1355.0  \n",
       "6   True  2851.0  2851.0  \n",
       "7   True   372.0   372.0  \n",
       "8   True  3309.0  3309.0  \n",
       "9   True  1326.0  1326.0  "
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target = pd.DataFrame(target)\n",
    "target['label'] = np.where(target['death'], target['futime'], - target['futime'])\n",
    "\n",
    "data = data.join(target)\n",
    "data.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = train_test_split(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "features = ['age', 'chapter', 'creatinine', 'flc.grp', 'kappa', 'lambda', 'mgus', 'sample.yr', 'sex']\n",
    "cat_features = ['sex', 'mgus', 'chapter', 'flc.grp', 'sample.yr']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "from catboost import CatBoostRegressor, Pool\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "vals = train.mode().iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = train.fillna(vals)\n",
    "test = test.fillna(vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_pool = Pool(train[features], label=train['label'], cat_features=cat_features)\n",
    "test_pool = Pool(test[features], label=test['label'], cat_features=cat_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CatBoostRegressor(iterations=100,\n",
    "                          loss_function='Cox',\n",
    "                          eval_metric='Cox',\n",
    "                          metric_period=50\n",
    "                         )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0:\tlearn: 13717.0120433\ttest: 3905.4239773\tbest: 3905.4239773 (0)\ttotal: 6.79ms\tremaining: 672ms\n",
      "50:\tlearn: 13384.7022346\ttest: 3804.2275700\tbest: 3804.2275700 (50)\ttotal: 249ms\tremaining: 239ms\n",
      "99:\tlearn: 13124.4787555\ttest: 3723.5118027\tbest: 3723.5118027 (99)\ttotal: 475ms\tremaining: 0us\n",
      "\n",
      "bestTest = 3723.511803\n",
      "bestIteration = 99\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<catboost.core.CatBoostRegressor at 0x7fba70eb4da0>"
      ]
     },
     "execution_count": 173,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(train_pool, eval_set=test_pool)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate metric "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install lifelines\n",
    "from lifelines.utils import concordance_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train set. Concordance index:0.84141\n",
      "Test set. Concordance index:0.83833\n"
     ]
    }
   ],
   "source": [
    "train_y_pred = model.predict(train_pool)\n",
    "test_y_pred = model.predict(test_pool)\n",
    "\n",
    "ci_train = concordance_index(train['futime'], -train_y_pred, train['death'])\n",
    "ci_test = concordance_index(test['futime'], -test_y_pred, test['death'])\n",
    "print(f'Train set. Concordance index:{ci_train:0.5f}')\n",
    "print(f'Test set. Concordance index:{ci_test:0.5f}')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "### Accelerated Failure Time (AFT) model "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "Base parametric model  \n",
    "$$\\log T_i = \\langle w, x_i \\rangle + \\epsilon_i$$  \n",
    "$T_i$ - survival time of the i-th subject\n",
    "$w$ - learnable parameters  \n",
    "$x_i$ - subject time-invariant feature description  \n",
    "$\\epsilon_i$ - error term with. $\\epsilon_i \\sim \\sigma Z_i$ where $Z_i$ is a prespecified a distribution and $\\sigma$ is a scaling hyperparamater.\n",
    "\n",
    "Idea - replace linear term $\\langle w, x \\rangle$ with boosted decision trees classifier $H(x)$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "### Fitting process "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "As $T_i$ is random variable, we can utilize common approach - maximize likelihood function across $n$ observed subjects $$L= \\prod_{i=1}^n f_T(t=T_i) \\sim \\sum_{i=1}^n \\log f_T(t=T_i)$$ where $f_T(t)$ is a PDF of $T = g(\\epsilon) = e^{H(x)+\\epsilon}$ prespecified distribution's PDF.  \n",
    "With an eye on interval form of censored target, likelihood function can be decomposed into two parts:   \n",
    "a) likelihood for uncensored data  \n",
    "b) likelihood for arbitrary censored data  \n",
    "$$\\log L = \\sum_{i=1}^n \\log f_T(t=T_i) = \\sum_{a} \\log f_T(t=T_a) + \\sum_{b} \\log (F_T(t=\\overline{T_b})-F_T(t=\\underline{T_{b}}))$$\n",
    "where $F_T(t)$ is a prespecified distribution's CDF  \n",
    "$\\overline{T}, \\underline{T}$ - upper and lower bounds of the interval  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "In Catboost Z can be distributed according to one of three distributions:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "| Distribution | CDF | PDF |\n",
    "| --- | --- | --- |\n",
    "| Normal |$$\\dfrac{1}{2}(1+erf( \\dfrac{z}{\\sqrt{2}})) $$ | $$\\dfrac{e^{\\dfrac{-z^2}{2}}}{\\sqrt{2\\pi}}$$|\n",
    "| Logistic | $$\\dfrac{e^z}{1+e^z}$$| $$\\dfrac{e^z}{(1+e^z)^2}$$ |\n",
    "| Extreme | $$1-e^{-e^z}$$  | $$e^ze^{-e^z}$$  |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "In order to use censored data in Catboost one should express target variable in interval form:\n",
    " * **Right-censored:** $[ A, +\\infty ]$\n",
    " * **Left-censored:** $[0, B]$\n",
    " * **Interval-censored:** $[A, B]$\n",
    " * **Uncensored:** $[A, A]$  \n",
    "where $A$, $B$ - finite non-negative numbers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "### Fit model "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "#### Load dataset "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are going to use [WHAS500* dataset](https://github.com/sebp/scikit-survival/tree/master/sksurv/datasets/data) and will try to predict patients' lifetime after the hospital admission.  \n",
    "*Table 1.2\n",
    "of Hosmer, D.W. and Lemeshow, S. and May, S. (2008) Applied Survival\n",
    "Analysis: Regression Modeling of Time to Event Data: Second Edition,\n",
    "John Wiley and Sons Inc., New York, NY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from catboost import CatBoostRegressor, Pool\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sksurv.datasets import load_whas500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "data, target = load_whas500()\n",
    "data = data.join(pd.DataFrame(target))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "In order to use right-censored data in catboost use -1 to express $+\\infty$. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "data['y_lower'] = data['lenfol']\n",
    "data['y_upper'] = np.where(data['fstat'], data['lenfol'], -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "stratifying_column = data['fstat']\n",
    "data = data.drop(['fstat','lenfol'],axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "train, test = train_test_split(data, test_size=0.2, stratify=stratifying_column, random_state=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>afb</th>\n",
       "      <th>age</th>\n",
       "      <th>av3</th>\n",
       "      <th>bmi</th>\n",
       "      <th>chf</th>\n",
       "      <th>cvd</th>\n",
       "      <th>diasbp</th>\n",
       "      <th>gender</th>\n",
       "      <th>hr</th>\n",
       "      <th>los</th>\n",
       "      <th>miord</th>\n",
       "      <th>mitype</th>\n",
       "      <th>sho</th>\n",
       "      <th>sysbp</th>\n",
       "      <th>y_lower</th>\n",
       "      <th>y_upper</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>83.0</td>\n",
       "      <td>0</td>\n",
       "      <td>25.54051</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>78.0</td>\n",
       "      <td>0</td>\n",
       "      <td>89.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>152.0</td>\n",
       "      <td>2178.0</td>\n",
       "      <td>-1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>49.0</td>\n",
       "      <td>0</td>\n",
       "      <td>24.02398</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>60.0</td>\n",
       "      <td>0</td>\n",
       "      <td>84.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>120.0</td>\n",
       "      <td>2172.0</td>\n",
       "      <td>-1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>70.0</td>\n",
       "      <td>0</td>\n",
       "      <td>22.14290</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>88.0</td>\n",
       "      <td>1</td>\n",
       "      <td>83.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>147.0</td>\n",
       "      <td>2190.0</td>\n",
       "      <td>-1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>70.0</td>\n",
       "      <td>0</td>\n",
       "      <td>26.63187</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>76.0</td>\n",
       "      <td>0</td>\n",
       "      <td>65.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>123.0</td>\n",
       "      <td>297.0</td>\n",
       "      <td>297.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>70.0</td>\n",
       "      <td>0</td>\n",
       "      <td>24.41255</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>85.0</td>\n",
       "      <td>0</td>\n",
       "      <td>63.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>135.0</td>\n",
       "      <td>2131.0</td>\n",
       "      <td>-1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  afb   age av3       bmi chf cvd  diasbp gender    hr   los miord mitype sho  \\\n",
       "0   1  83.0   0  25.54051   0   1    78.0      0  89.0   5.0     1      0   0   \n",
       "1   0  49.0   0  24.02398   0   1    60.0      0  84.0   5.0     0      1   0   \n",
       "2   0  70.0   0  22.14290   0   0    88.0      1  83.0   5.0     0      1   0   \n",
       "3   0  70.0   0  26.63187   1   1    76.0      0  65.0  10.0     0      1   0   \n",
       "4   0  70.0   0  24.41255   0   1    85.0      0  63.0   6.0     0      1   0   \n",
       "\n",
       "   sysbp  y_lower  y_upper  \n",
       "0  152.0   2178.0     -1.0  \n",
       "1  120.0   2172.0     -1.0  \n",
       "2  147.0   2190.0     -1.0  \n",
       "3  123.0    297.0    297.0  \n",
       "4  135.0   2131.0     -1.0  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "features = data.columns.difference(['y_lower', 'y_upper'], sort=False)\n",
    "cat_features = ['cvd','afb','sho','chf','av3', 'miord', 'mitype', 'gender']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "train_pool = Pool(train[features], label=train.loc[:,['y_lower','y_upper']], cat_features=cat_features)\n",
    "test_pool = Pool(test[features], label=test.loc[:,['y_lower','y_upper']], cat_features=cat_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "loss function format: SurvivalAft:dist={distname};scale={int}  \n",
    "Hyperparameters:\n",
    "* distname $\\in$ [Normal, Logistic, Extreme]\n",
    "* scale - any positive number representing distribution scale\n",
    "\n",
    "eval metric: SurvivalAft   \n",
    "Metric is calculated as \"interval MAE\":  \n",
    "$$\\text{Interval MAE} = \\begin{cases}\n",
    "0 & \\text{if $\\hat{y} \\in [\\underline{y}, \\overline{y}]$}\\\\\n",
    "\\min{(|\\hat{y}-\\underline{y}|,|\\hat{y}-\\overline{y}|)} & \\text{if $\\hat{y} \\notin [\\underline{y}, \\overline{y}]$}\n",
    "\\end{cases}$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "model_normal = CatBoostRegressor(iterations=500,\n",
    "                                 loss_function='SurvivalAft:dist=Normal',\n",
    "                                 eval_metric='SurvivalAft',\n",
    "                                 verbose=0\n",
    "                                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "model_logistic = CatBoostRegressor(iterations=500,\n",
    "                                 loss_function='SurvivalAft:dist=Logistic;scale=1.2',\n",
    "                                 eval_metric='SurvivalAft',\n",
    "                                 verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "model_extreme = CatBoostRegressor(iterations=500,\n",
    "                                 loss_function='SurvivalAft:dist=Extreme;scale=2',\n",
    "                                 eval_metric='SurvivalAft',\n",
    "                                 verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "_ = model_normal.fit(train_pool, eval_set=test_pool)\n",
    "_ = model_logistic.fit(train_pool, eval_set=test_pool)\n",
    "_ = model_extreme.fit(train_pool, eval_set=test_pool)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_predictions = pd.DataFrame({'y_lower': train['y_lower'],\n",
    "                                  'y_upper': train['y_upper'],\n",
    "                                  'preds_normal': model_normal.predict(train_pool, prediction_type='Exponent'),\n",
    "                                  'preds_logistic': model_logistic.predict(train_pool, prediction_type='Exponent'),\n",
    "                                  'preds_extreme': model_extreme.predict(train_pool, prediction_type='Exponent')})\n",
    "train_predictions['y_upper'] = np.where(train_predictions['y_upper']==-1, np.inf, train_predictions['y_upper'])\n",
    "\n",
    "test_predictions = pd.DataFrame({'y_lower': test['y_lower'],\n",
    "                                  'y_upper': test['y_upper'],\n",
    "                                  'preds_normal': model_normal.predict(test_pool, prediction_type='Exponent'),\n",
    "                                  'preds_logistic': model_logistic.predict(test_pool, prediction_type='Exponent'),\n",
    "                                  'preds_extreme': model_extreme.predict(test_pool, prediction_type='Exponent')})\n",
    "test_predictions['y_upper'] = np.where(test_predictions['y_upper']==-1, np.inf, test_predictions['y_upper'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate metric "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def interval_mae(y_true_lower, y_true_upper, y_pred):\n",
    "    mae = np.where((y_true_lower <= y_pred) & (y_pred <= y_true_upper),\n",
    "                   0,\n",
    "                   np.minimum(np.abs(y_true_lower-y_pred),\n",
    "                              np.abs(y_true_upper-y_pred))) \n",
    "    return mae.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Interval MAE\n",
      "Train set. dist:normal: 197.36\n",
      "Test set. dist:normal: 348.89\n",
      "---------------------------\n",
      "Train set. dist:logistic: 401.47\n",
      "Test set. dist:logistic: 477.54\n",
      "---------------------------\n",
      "Train set. dist:extreme: 324.11\n",
      "Test set. dist:extreme: 307.46\n",
      "---------------------------\n"
     ]
    }
   ],
   "source": [
    "distributions = ['normal', 'logistic', 'extreme']\n",
    "print('Interval MAE')\n",
    "for dist in distributions:\n",
    "    train_metric = interval_mae(train_predictions['y_lower'], train_predictions['y_upper'], train_predictions[f'preds_{dist}'])\n",
    "    test_metric = interval_mae(test_predictions['y_lower'], test_predictions['y_upper'], test_predictions[f'preds_{dist}'])\n",
    "    print(f'Train set. dist:{dist}: {train_metric:0.2f}')\n",
    "    print(f'Test set. dist:{dist}: {test_metric:0.2f}')\n",
    "    print('---------------------------')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
